import argparse
import torch
import os
os.environ["CUDA_VISIBLE_DEVICES"] = "0"
from torchvision.utils import save_image
from PIL import Image
import neptune
import time
import argparse
import torch
import os
import json
from tqdm import tqdm
from typing import Dict, Optional, Sequence, List
from PIL import Image
# ----------------- Imports -----------------
from llava_v15.model.builder import load_pretrained_model
from llava_v15.utils import disable_torch_init
from llava_v15.mm_utils import tokenizer_image_token, get_model_name_from_path
from llava_v15.constants import DEFAULT_IMAGE_TOKEN, IMAGE_TOKEN_INDEX
from llava_v15.conversation import conv_templates
from torchvision import transforms


def parse_args(argv=None):
    """Parse the command-line options for the LLaVA-1.5 visual attack script.

    Args:
        argv: Optional list of argument strings to parse instead of
            ``sys.argv[1:]``. The default ``None`` reads ``sys.argv``, so the
            existing zero-argument ``parse_args()`` call behaves as before.

    Returns:
        argparse.Namespace holding the parsed options (note that
        ``--conv-mode`` is stored as ``conv_mode``).
    """
    parser = argparse.ArgumentParser(description="Demo")
    parser.add_argument("--gpu_id", type=int, default=2, help="specify the gpu to load the model.")
    parser.add_argument("--n_iters", type=int, default=5000, help="specify the number of iterations for attack.")
    parser.add_argument("--batch_size", type=int, default=1, help="specify the batch size for imagenet.")
    # eps / alpha are integers on a 0-255 pixel scale; the script divides
    # them by 255 before handing them to the attacker.
    parser.add_argument('--eps', type=int, default=32, help="epsilon of the attack budget")
    parser.add_argument('--alpha', type=int, default=1, help="step_size of the attack")
    # The default directory name (including its spelling) is kept as-is because
    # existing runs write their outputs there.
    parser.add_argument("--save_dir", type=str, default='defalut',
                        help="save directory")
    parser.add_argument("--conv-mode", type=str, default="llava_v1")
    parser.add_argument("--temperature", type=float, default=0.2)
    parser.add_argument("--top_p", type=float, default=None)
    parser.add_argument("--num_beams", type=int, default=1)
    # NOTE(review): this help text looks copied from --alpha; confirm what
    # --select actually controls in the attacker and update it.
    parser.add_argument('--select', type=int, default=1, help="step_size of the attack")
    parser.add_argument("--ours", default=False, action='store_true')
    # Threshold is fractional (default 0.2), so it must be parsed as float;
    # type=int would reject values such as "--th 0.3".
    parser.add_argument('--th', type=float, default=0.2, help="tau value")

    return parser.parse_args(argv)

args = parse_args()
# Expose only the requested GPU to CUDA. This overrides the "0" assigned near
# the top of the file. It only takes effect if CUDA has not been initialised
# yet (torch initialises CUDA lazily), so it must stay ahead of the model
# load below. Assumes nothing imported above has touched CUDA — TODO confirm.
os.environ.update({"CUDA_VISIBLE_DEVICES": f"{args.gpu_id}"})

def load_image(image_path):
    """Load an image file and return it converted to RGB.

    The file is opened in a ``with`` block so its handle is closed once the
    RGB copy exists. ``convert`` returns a new, fully loaded image, so the
    result stays valid after the file is closed. Without the ``with``, Pillow
    leaves the file open for multi-frame formats.

    Args:
        image_path: Path to the image file on disk.

    Returns:
        A new ``PIL.Image.Image`` in ``'RGB'`` mode.
    """
    with Image.open(image_path) as img:
        return img.convert('RGB')

# ========================================
#             Model Initialization
# ========================================

print('>>> Initializing Model')
import warnings
# Silence all warnings for the rest of the run.
warnings.filterwarnings("ignore")
# ----------------- Model Load -----------------
pretrained = "liuhaotian/llava-v1.5-13b"
model_name = "llava_v1.5"
tokenizer, model, processor, context_len = load_pretrained_model(
    model_path=pretrained,
    model_base=None,
    model_name=model_name,
    device_map="auto"
)
model.eval()

print('[Initialization Finished]\n')

# Create the output directory once, including any missing parent directories;
# this is a no-op when it already exists.
os.makedirs(args.save_dir, exist_ok=True)

import csv

# Attack targets are the first column of every non-empty row of the corpus.
# newline='' is what the csv module expects from its file object. Blank rows
# (which csv.reader yields as []) are skipped rather than raising IndexError.
with open("harmful_corpus/derogatory_corpus.csv", "r", newline="", encoding="utf-8") as corpus_file:
    data = list(csv.reader(corpus_file, delimiter=","))
targets = [row[0] for row in data if row]

print('device = ', model.device)

template_img = 'adversarial_images/n02510455_405.JPEG'
image = load_image(template_img)
# NOTE(review): .cuda() uses the current CUDA device while the model was loaded
# with device_map="auto"; presumably both resolve to cuda:0 — confirm.
image = processor.preprocess(image, return_tensors='pt')['pixel_values'].cuda()
print(image.shape)

from llava_v15_utils import visual_attacker_llava_15 as visual_attacker

my_attacker = visual_attacker.Attacker(args, model, tokenizer, targets, device=model.device, image_processor=processor, run=None)

print(args.save_dir)

# eps/alpha are CLI integers on a 0-255 scale; rescale them to the [0, 1]
# range before the attack.
# NOTE(review): the return value is not used here; presumably the attacker
# saves its own outputs under args.save_dir — confirm.
adv_img_prompt = my_attacker.targeted_attack_B2H(img=image,
                                                batch_size=args.batch_size,
                                                num_iter=args.n_iters,
                                                alpha=args.alpha / 255,
                                                epsilon=args.eps / 255,
                                                ours=args.ours,
                                                )

print('[Done]')


